# -*- coding: utf-8 -*-  
'''
初始化百度数据集相关工作

@author: luoyi
Created on 2021年6月23日
'''
import data.dataset_baidu as ds_baidu
from data.dataset_tfrecord_baidu import CRFLinkerTFRecordWriter

import utils.conf as conf
import utils.relationships as rel
import utils.dictionaries as dicts


rel.load_rel_id_from_pkl()
dicts.load_dict_from_pkl()


#    初始化关系数据
print('初始化关系数据')
ds_baidu.save_rel_id_to_pkl(file_path=conf.DATASET_BAIDU.get_schemas_path(), 
                            rel_id_pkl_path=conf.DATASET_BAIDU.get_rel_id_path(), 
                            id_rel_pkl_path=conf.DATASET_BAIDU.get_id_rel_path())
print('初始化关系数据完成.')

#    写入tfrecord文件
train_writer = CRFLinkerTFRecordWriter(original_file_path=conf.DATASET_BAIDU.get_train_data_path(), 
                                       out_file_path=conf.DATASET_BAIDU.get_train_crflinker_dataset_path(), 
                                       max_sen_len=conf.TPLINKER.get_max_sentence_len(), 
                                       rel_size=len(rel.id_rel),
                                       record_count=conf.DATASET_BAIDU.get_record_count(),
                                      )
print('写入训练数据. tfrecord模式')
train_writer.save_to_tfrecord(show_cross=False)
print('写入训练数据. tfrecord模式。 完成.')
#    crflinker数据 
#    train 最大长度：128，写入文件总数：163840
#    val 最大长度：128，写入文件总数：16384


#    写入tfrecord文件
print('写入验证数据. tfrecord模式')
val_writer = CRFLinkerTFRecordWriter(original_file_path=conf.DATASET_BAIDU.get_val_data_path(), 
                                     out_file_path=conf.DATASET_BAIDU.get_val_crflinker_dataset_path(), 
                                     max_sen_len=conf.TPLINKER.get_max_sentence_len(), 
                                     rel_size=len(rel.id_rel),
                                     record_count=conf.DATASET_BAIDU.get_record_count())
val_writer.save_to_tfrecord(show_cross=False)
print('写入验证数据. tfrecord模式。完成.')
